Linear regression is widely regarded as a backbone of modern statistics. Even as more complex, “black box”-style machine learning techniques grow in popularity, many statisticians and researchers still fall back on regression for its interpretability and simplicity. However, linear regression relies on a number of assumptions that may not hold in practice, such as a constant, linear relationship between each predictor variable and the response. In this guide, we explore the use of splines to model predictor variables whose relationship with the response changes across their domain. These techniques help us approach the predictive power of some more advanced machine learning algorithms while keeping the benefits of regression. We show examples in three popular statistical modelling languages: Python, R, and Stata.
In this guide, we will be using the Wage dataset from the R package ISLR. This data is also used in the book Introduction to Statistical Learning. The dataset contains wages for 3,000 male workers in the Mid-Atlantic region, recorded between 2003 and 2009, along with a small set of demographic variables. We retain the variables for wage, age, year, and education for our analysis. Our goal is to examine the relationship between age, year, and education and workers’ yearly wage.
Below is a table and plots explaining the variables used in this example.
| Variables | Role | Type | Explanation |
|---|---|---|---|
| Wage | Response | Numerical | Worker’s Wage |
| Age | Predictor | Numerical | Age of the worker |
| Year | Predictor | Numerical | Year when the data was collected |
| Education | Predictor | Categorical | The education of that worker |
Table 0.1: The type, role, and explanation of each variable.
Figure 0.2: Histogram or Bar chart for variables used in the analysis
We will first fit a simple linear regression as a baseline. We will then apply four different spline-like techniques to the “age” predictor variable: a step function, polynomial regression, a basis spline, and a natural spline. At each step, we will check the quality of the fit, noting any improvements along the way. We will conclude with a retrospective and summary of what we learned.
#!/usr/bin/env python
# coding: utf-8
# In[1]:
#Packages required
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')
#%matplotlib inline
import statsmodels.api as sm
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from patsy import dmatrix
# In[2]:
#Let's read in the data file
data = pd.read_csv("/Users/kwschulz/STATS506/Stats506_Project/Dataset/data.csv")
# In[3]:
#Take a quick glance at what the data looks like
data.head()
# In[4]:
#Let's check to see if we have any missing values
data.isna().sum()
# In[5]:
#Filter to variables for our analysis
data = data[["wage", "age", "education", "year"]].copy()  #copy so later column assignments do not hit a view
# In[6]:
#Map education to ordinal scale
education_map = {"1. < HS Grad":1,"2. HS Grad":2,
"3. Some College":3, "4. College Grad":4,
"5. Advanced Degree":5}
data['education'] = data.education.map(education_map)
# In[7]:
#Let's check the distributions of wage and age
data[["wage", "age"]].hist(layout=(2,1), figsize=(15,15))
#save before showing; calling savefig after show() can write an empty figure
plt.savefig('hist.png')
plt.show()
# In[8]:
#checking year distribution
data.year.value_counts().reindex([2003, 2004, 2005, 2006, 2007, 2008, 2009]).plot(kind='bar',
title='year',
ylabel='count',
figsize=(7.5,7.5))
plt.savefig('year_bar.png')
# In[9]:
#checking education distribution
data.education.value_counts().reindex([1, 2, 3, 4, 5]).plot(kind='bar',
title='education',
ylabel='count',
figsize=(7.5,7.5))
plt.savefig('education_bar.png')
# In[10]:
#linear regression model
model = sm.OLS(data["wage"], sm.add_constant(data.drop('wage',axis=1))).fit()
# In[11]:
#let's check how it did
model.summary()
# In[12]:
#let's cut age into 6 bins - stepwise
data["age_cut"] = pd.cut(data.age, bins=6, labels=False)
# In[13]:
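#now let's model age with bins
#dummy-code the bin index so each bin gets its own step instead of one linear slope
age_dummies = pd.get_dummies(data["age_cut"], prefix="age_bin", drop_first=True, dtype=float)
model2 = sm.OLS(data["wage"], sm.add_constant(pd.concat([age_dummies, data[["education", "year"]]], axis=1))).fit()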
# In[14]:
#model 2 summary
model2.summary()
# In[15]:
#let's check out the scatter plot of age v wage
data.plot(x="age", y="wage", kind='scatter', figsize=(7.5,7.5))
# In[16]:
#2nd degree polynomial
p = np.poly1d(np.polyfit(data["age"], data["wage"], 2))
t = np.linspace(0, 80, 200)
plt.plot(data["age"], data["wage"], 'o', t, p(t), '-')
#r-squared needs the powers 0..2 to match the degree-2 fit above
rs = sm.OLS(data["wage"],
np.column_stack([data["age"]**i for i in range(3)]) ).fit().rsquared
plt.title('r2 = {}'.format(rs))
plt.savefig('poly2.png')
plt.show()
# In[17]:
#3rd degree polynomial
p = np.poly1d(np.polyfit(data["age"], data["wage"], 3))
t = np.linspace(0, 80, 200)
plt.plot(data["age"], data["wage"], 'o', t, p(t), '-')
#r-squared needs the powers 0..3 to match the degree-3 fit above
rs = sm.OLS(data["wage"],
np.column_stack([data["age"]**i for i in range(4)]) ).fit().rsquared
plt.title('r2 = {}'.format(rs))
plt.savefig('poly3.png')
plt.show()
# In[18]:
#4th degree polynomial
p = np.poly1d(np.polyfit(data["age"], data["wage"], 4))
t = np.linspace(0, 80, 200)
plt.plot(data["age"], data["wage"], 'o', t, p(t), '-')
#r-squared needs the powers 0..4 to match the degree-4 fit above
rs = sm.OLS(data["wage"],
np.column_stack([data["age"]**i for i in range(5)]) ).fit().rsquared
plt.title('r2 = {}'.format(rs))
plt.savefig('poly4.png')
plt.show()
# In[19]:
#5th degree polynomial
p = np.poly1d(np.polyfit(data["age"], data["wage"], 5))
t = np.linspace(0, 80, 200)
plt.plot(data["age"], data["wage"], 'o', t, p(t), '-')
#r-squared needs the powers 0..5 to match the degree-5 fit above
rs = sm.OLS(data["wage"],
np.column_stack([data["age"]**i for i in range(6)]) ).fit().rsquared
plt.title('r2 = {}'.format(rs))
plt.savefig('poly5.png')
plt.show()
# In[20]:
#let's fit a third-degree polynomial regression
polynomial_features= PolynomialFeatures(degree=3)
age_p = polynomial_features.fit_transform(data['age'].to_numpy().reshape(-1, 1))
model3 = sm.OLS(data["wage"], sm.add_constant(np.concatenate([data[['education', 'year']].to_numpy(), age_p], axis=1))).fit()
# In[21]:
#check our results
model3.summary(xname=['education', 'year', 'const', 'poly(age, 3)1', 'poly(age, 3)2', 'poly(age, 3)3'])
# In[22]:
#implementing a bspline for age
age_bs = dmatrix("bs(data.age, df=6)",{"data.age": data.age}, return_type='dataframe')
model4 = sm.OLS(data["wage"], pd.concat([age_bs, data[['education', 'year']]], axis=1)).fit()
model4.summary()
# In[23]:
#implementing a natural spline for age
age_ns = dmatrix("cr(data.age, df=6)",{"data.age": data.age}, return_type='dataframe')
model5 = sm.OLS(data["wage"], pd.concat([age_ns, data[['education', 'year']]], axis=1)).fit()
model5.summary()
Before fitting splines, we first fit an OLS regression of wage on age, year, and education. We can run the simple code below to look at this relationship.
reg wage age year edu
Stata will return the following output:
Figure 2.1 OLS output for wage ~ age + year + education
To see if a non-linear relationship might be present, kernel density, pnorm, and qnorm plots of the residuals can assist:
predict r, resid
kdensity r, normal
pnorm r
qnorm r
Figure 2.2 Kernel Density Plot
Figure 2.3 Pnorm Plot
Figure 2.4 Qnorm Plot
After looking at these plots we might consider the different relationships that age may have with wage. We can plot the two-way fit between wage and age, our main variables of interest, to compare a linear fit (lfit), a fractional polynomial fit (fpfit), and a quadratic fit (qfit).
twoway (scatter wage age) (lfit wage age) (fpfit wage age) (qfit wage age)
Figure 2.5 Fitted Plot - Linear (red), fractional polynomial (green), and quadratic (yellow) fit
Based on these plots we might be interested in trying a cubic polynomial fit next.
To create a cubic polynomial in Stata we can apply the factor-variable interaction operator ## to the continuous age variable (c.age). The regression is written as before with the addition of the cubic terms:
reg wage c.age##c.age##c.age year educ
The output in Stata will look like this:
Figure 2.6 Regression with Cubic polynomial for Age
For the piecewise fit, the bins and their intercepts must be constructed manually in Stata. Based on analysis in R we determined that including 6 groups with 5 cutpoints is best. The code below generates, for each of the six age bins, a within-bin age term (which allows a separate slope) and an intercept indicator.
* generate 6 age variables, one for each bin *
* the age variable does not have decimals *
generate age1 = (age - 28.33)
replace age1 = 0 if (age >= 28.33)
generate age2 = (age-38.66)
replace age2 = 0 if age <28.33 | age >= 38.66
generate age3 = (age- 48.99)
replace age3 = 0 if age <38.66 | age >=48.99
generate age4 = (age - 59.33)
replace age4 = 0 if age <48.99 | age >= 59.33
generate age5 = (age - 69.66)
replace age5= 0 if age < 59.33 | age>=69.66
generate age6 = (age-80)
replace age6 = 0 if age <69.66
* create intercept variables*
generate int1 = 1
replace int1 = 0 if age >= 28.33
generate int2 = 1
replace int2 = 0 if age <28.33 | age >= 38.66
generate int3 = 1
replace int3 = 0 if age <38.66 | age >=48.99
generate int4 = 1
replace int4 = 0 if age <48.99 | age >= 59.33
generate int5 = 1
replace int5= 0 if age < 59.33 | age>=69.66
generate int6 = 1
replace int6 = 0 if age <69.66
Using these variables we can then fit the piecewise regression.
regress wage int1 int2 int3 int4 int5 int6 age1 age2 age3 age4 age5 age6 ///
year educ, hascons
Figure 2.7 Piecewise regression for Age with 6 bins
After running the regression we can then use the predicted yhats to graph the results:
predict yhat
twoway (scatter wage age, sort) ///
(line yhat age if age <28.33, sort) ///
(line yhat age if age >=28.33 & age < 38.66, sort) ///
(line yhat age if age >=38.66 & age < 48.99, sort) ///
(line yhat age if age >=48.99 & age<59.33, sort) ///
(line yhat age if age >=59.33 & age<69.66, sort) ///
(line yhat age if age >=69.66, sort), xline(28.33 38.66 48.99 59.33 69.66)
Figure 2.8 Fitted piecewise regression for Age with 6 bins
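The cut points hard-coded in the Stata code above (28.33, 38.66, and so on) appear to be equal-width bins over the observed age range. The following is a minimal sketch in Python with NumPy, not part of the original analysis, of how such cut points can be derived. It assumes ages span 18 to 80 (the boundary knots used later in this guide); the variable names are illustrative only. The computed values agree with the Stata ones up to rounding, which makes no difference here because age is recorded in whole years.

```python
import numpy as np

# Assumption: ages span 18 to 80, the boundary knots used later in this guide.
AGE_MIN, AGE_MAX, N_BINS = 18, 80, 6

# N_BINS equal-width bins need N_BINS + 1 edges; the interior edges are the cut points.
edges = np.linspace(AGE_MIN, AGE_MAX, N_BINS + 1)
cutpoints = edges[1:-1]
print(np.round(cutpoints, 2))

# Assign a few example ages to bins 0..5. With right=False an age equal to a
# cut point falls into the upper bin, like the ">=" comparisons in the Stata code.
ages = np.array([18, 28, 29, 45, 60, 80])
bins = np.digitize(ages, cutpoints, right=False)
print(bins)
```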
For the basis spline, we use the bspline command, written by Roger Newson and suggested by Germán Rodríguez at Princeton. To create the spline, we call bspline with age as the x variable and specify where the knots should go. This example uses three interior knots at 35, 50, and 65; note that the minimum and maximum of age (18 and 80) must also be listed inside knots(). The option p(3) requests a cubic spline, and gen(_agespt) names the generated spline variables that enter the regression. The regression can then be written as below.
bspline, xvar(age) knots(18 35 50 65 80) p(3) gen(_agespt)
regress wage _agespt* year educ, noconstant
Figure 2.9 Basis Spline Regression
To look at the fit for age, we can examine the two-way scatter plot between wage and age using the predicted values of the bivariate regression with splines.
regress wage _agespt*, noconstant
predict agespt
*(option xb assumed; fitted values)
twoway (scatter wage age)(line agespt age, sort), legend(off) ///
title(Basis Spline for Age)
Figure 2.10 Basis spline fit for Age
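To make the idea of a cubic B-spline basis with knots at 18, 35, 50, 65, and 80 more concrete, here is a small NumPy sketch that builds such a basis with the Cox-de Boor recursion. It is an illustration under stated assumptions, not a reproduction of Stata's bspline: the function name `bspline_basis` is hypothetical, it uses a clamped knot vector (boundary knots repeated), and the number and scaling of columns that bspline generates may differ.

```python
import numpy as np

def bspline_basis(x, knots, degree=3):
    """Evaluate a clamped B-spline basis of the given degree at the points x.

    `knots` lists boundary and interior knots in increasing order; each
    boundary knot is repeated `degree` extra times to clamp the basis.
    """
    x = np.asarray(x, dtype=float)
    t = np.concatenate([[knots[0]] * degree, knots, [knots[-1]] * degree]).astype(float)
    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1}).
    B = np.array([(t[i] <= x) & (x < t[i + 1]) for i in range(len(t) - 1)], dtype=float).T
    # Put x equal to the last knot into the last non-empty interval.
    B[x == t[-1], len(t) - degree - 2] = 1.0
    # Cox-de Boor recursion: build degree d from degree d - 1 (terms with 0/0 are skipped).
    for d in range(1, degree + 1):
        B_new = np.zeros((len(x), len(t) - d - 1))
        for i in range(len(t) - d - 1):
            left_den = t[i + d] - t[i]
            right_den = t[i + d + 1] - t[i + 1]
            if left_den > 0:
                B_new[:, i] += (x - t[i]) / left_den * B[:, i]
            if right_den > 0:
                B_new[:, i] += (t[i + d + 1] - x) / right_den * B[:, i + 1]
        B = B_new
    return B

ages = np.arange(18, 81)  # every whole-year age from 18 to 80
basis = bspline_basis(ages, knots=[18, 35, 50, 65, 80], degree=3)
print(basis.shape)  # (63, 7): 3 interior knots + degree 3 + 1
print(np.allclose(basis.sum(axis=1), 1.0))
```

In this clamped basis the columns sum to one at every age, so a separate intercept would be redundant. That is one common reason spline regressions of this kind are fit without a constant.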
This further extension is still being coded. Please see the README.md file.
Here are the required libraries for this tutorial.
library(tidyverse) ## This library is for data manipulation.
library(ggplot2) ## This library is for data visualization.
library(gridExtra) ## This library is also for data visualization.
library(splines) ## This library is for spline.
In this example, the author has written two additional functions. All of the code for the user-written functions is in User-written function.R.
wage_age: plots the scatter plot of wage against age, optionally including the regression line.
plot_kfold: performs the 5-fold cross-validation used to select the degrees of freedom for the basis and natural splines.
User-written function.R
## Required Libraries
library(tidyverse)
library(ggplot2)
### First Function: Scatter Plot between Wage and Age
### with a regression line.
wage_age <- function(line = FALSE, poly = 1, formula_wage_age = y ~ x){
  plot_title <- "Scatter Plot between Wage and Age"
  if(poly != 1){
    plot_title <- paste0("Polynomial degree ", poly)
  }
  plot_wage_age <- ggplot(data, aes(x = age, y = wage)) +
    geom_point(color = "darkblue") +
    theme_bw() +
    labs(title = plot_title, x = "Age", y = "Wage")
  if(line == TRUE){
    plot_wage_age +
      geom_smooth(method = "lm", formula = formula_wage_age, color = "yellow")
  } else {
    plot_wage_age
  }
}
## Second Function: Function for performing k-fold cross validation.
plot_kfold <- function(knots = 1:10, bs = TRUE){
  store_MSE <- c()
  title_plot <- "5-fold cross-validate MSE"
  if(bs == TRUE){
    title_plot <- paste0(title_plot, ": Basis Spline")
  } else {
    title_plot <- paste0(title_plot, ": Natural Spline")
  }
  for(i in knots){
    MSE <- c()
    for(j in 1:5){
      ## Split the data
      train_data <- data %>% filter(CV != j)
      validate_data <- data %>% filter(CV == j)
      ## Train the model
      if(bs == TRUE){
        model <- lm(wage ~ bs(age, df = 4 + i) + edu + year,
                    data = train_data)
      } else {
        model <- lm(wage ~ ns(age, df = 2 + i) + edu + year,
                    data = train_data)
      }
      ## Store the MSE for each validation set
      MSE <- c(MSE,
               mean((predict(model, validate_data) - validate_data$wage)^2))
    }
    ## Store the MSE for each knots
    store_MSE <- c(store_MSE, mean(MSE))
  }
  ## Create a plot
  dummy_result <- data.frame(knots = as.factor(knots), store_MSE)
  p <- ggplot(dummy_result, aes(x = knots, y = store_MSE, group = 1)) +
    geom_line() +
    geom_point() +
    labs(title = title_plot, x = "Number of knot", y = "MSE") +
    geom_text(aes(label = round(store_MSE,2)), vjust = -1) +
    ylim(min(store_MSE) - 3, max(store_MSE) + 3) +
    theme_bw()
  print(p)
}
For simplicity, the author decided to convert the Education variable into a numerical variable.
| Old Value | New Value |
|---|---|
| 1. < HS Grad | 1 |
| 2. HS Grad | 2 |
| 3. Some College | 3 |
| 4. College Grad | 4 |
| 5. Advanced Degree | 5 |
data <- data %>% mutate(edu = ifelse(education == "1. < HS Grad", 1,
ifelse(education == "2. HS Grad", 2,
ifelse(education == "3. Some College", 3,
ifelse(education == "4. College Grad", 4, 5)))))
First, consider the linear regression. Below is a model, along with the Quantile-Quantile plot.
model_lr <- lm(wage ~ age + edu + year, data = data)
summary(model_lr)
| wage | |||
|---|---|---|---|
| Predictors | Estimates | CI | p |
| (Intercept) | -2142.85 | -3420.48 – -865.21 | 0.001 |
| age | 0.58 | 0.47 – 0.69 | <0.001 |
| edu | 15.92 | 14.85 – 16.98 | <0.001 |
| year | 1.09 | 0.45 – 1.72 | 0.001 |
| Observations | 3000 | ||
| R2 / R2 adjusted | 0.256 / 0.255 | ||
Table 3.2 Regression Model when using age directly.
According to Table 3.2, the fitted model is \(Wage = -2142.85 + 0.58(Age) + 15.92(Edu) + 1.09(Year)\). The \(R^2\) for this model is 0.2555, which is quite low. We therefore look at the Quantile-Quantile plot of the residuals to check whether the model violates the normality assumption of linear regression.
Figure 3.3 The QQ plot, along with the Kernel density plot of the residual from the linear regression.
Figure 3.4 Scatter plot between Wage and Age
In Figure 3.3 (left), the residuals do not appear to be normally distributed, since a number of data points do not lie on the reference line.
In addition to Figure 3.3, Figure 3.4 shows that the relationship between Wage and Age is not linear; therefore, this tutorial will try different types of relationship between Wage and Age.
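As a rough illustration of what a QQ plot compares, the sketch below computes the pairs of theoretical and sample quantiles by hand in Python. It uses synthetic residuals rather than the Wage data, and the helper name `qq_points` is hypothetical. For normal residuals the points stay close to the line y = x, while skewed residuals drift away from it in the tails.

```python
import numpy as np
from statistics import NormalDist

def qq_points(residuals):
    """Return (theoretical, sample) quantile pairs for a normal QQ plot."""
    r = np.sort(np.asarray(residuals, dtype=float))
    n = len(r)
    # Plotting positions (i - 0.5) / n keep the probabilities strictly inside (0, 1).
    probs = (np.arange(1, n + 1) - 0.5) / n
    theoretical = np.array([NormalDist().inv_cdf(p) for p in probs])
    # Standardize so that normal residuals fall near the line y = x.
    sample = (r - r.mean()) / r.std(ddof=1)
    return theoretical, sample

rng = np.random.default_rng(0)
normal_resid = rng.normal(size=2000)               # synthetic, well-behaved residuals
skewed_resid = rng.exponential(size=2000) - 1.0    # synthetic, right-skewed residuals

deviations = {}
for name, resid in [("normal", normal_resid), ("skewed", skewed_resid)]:
    theo, samp = qq_points(resid)
    # Largest vertical distance from the reference line y = x.
    deviations[name] = float(np.max(np.abs(samp - theo)))
    print(name, round(deviations[name], 2))
```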
The first type that we will consider is polynomial regression. To fit it, use the poly() function inside the lm() formula.
model_poly <- lm(wage ~ poly(age, 3) + edu + year, data = data)
summary(model_poly)
| wage | |||
|---|---|---|---|
| Predictors | Estimates | CI | p |
| (Intercept) | -2330.05 | -3583.92 – -1076.18 | <0.001 |
| age [1st degree] | 369.89 | 300.40 – 439.37 | <0.001 |
| age [2nd degree] | -383.49 | -453.11 – -313.88 | <0.001 |
| age [3rd degree] | 80.58 | 11.22 – 149.95 | 0.023 |
| edu | 15.30 | 14.25 – 16.35 | <0.001 |
| year | 1.19 | 0.57 – 1.82 | <0.001 |
| Observations | 3000 | ||
| R2 / R2 adjusted | 0.285 / 0.283 | ||
Table 3.5 Cubic Polynomial Regression Model
Figure 3.6 Scatter plot between Wage and Age with a line indicate the cubic polynomial regression
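According to Table 3.5, the model is \(Wage = -2330.05 + 369.89P_{1}(Age) - 383.49P_{2}(Age) + 80.58P_{3}(Age) + 15.30(Edu) + 1.19(Year)\), where \(P_{k}(Age)\) is the k-th orthogonal polynomial term that poly() builds by default. These terms are not the raw powers \(Age^{k}\), so the age coefficients cannot be read as coefficients on \(Age\), \(Age^{2}\), and \(Age^{3}\). The corresponding \(R^2\) is 0.2846.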
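The difference between orthogonal and raw polynomial terms can be checked numerically. The sketch below, in Python with NumPy on synthetic data (not the Wage data), builds an orthogonal basis by QR-decomposing the centered powers of age, which mimics the default behaviour of poly() up to the sign of each column. It shows that the two bases give the same fitted values even though their coefficients look very different.

```python
import numpy as np

rng = np.random.default_rng(1)
age = rng.integers(18, 81, size=500).astype(float)
# Synthetic response with a curved age effect; these numbers are made up.
wage = 60 + 3.0 * age - 0.03 * age**2 + rng.normal(scale=10, size=500)

# Raw basis: 1, age, age^2, age^3
X_raw = np.vander(age, N=4, increasing=True)

# Orthogonal basis: QR-decompose the centered raw basis; the columns of Q are
# orthonormal, and all but the first are orthogonal to the constant.
Q, _ = np.linalg.qr(np.vander(age - age.mean(), N=4, increasing=True))
X_orth = np.column_stack([np.ones_like(age), Q[:, 1:]])

beta_raw, *_ = np.linalg.lstsq(X_raw, wage, rcond=None)
beta_orth, *_ = np.linalg.lstsq(X_orth, wage, rcond=None)

fit_raw = X_raw @ beta_raw
fit_orth = X_orth @ beta_orth
print(np.allclose(fit_raw, fit_orth, atol=1e-6))  # same fitted curve
print(beta_raw)    # coefficients on 1, age, age^2, age^3
print(beta_orth)   # coefficients on the orthogonal terms
```

Because the fitted values coincide, quantities such as \(R^2\) do not depend on which basis is used; only the interpretation of the individual coefficients changes.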
model_cut <- lm(wage ~ cut(age, 4) + edu + year, data = data)
summary(model_cut)
| wage | |||
|---|---|---|---|
| Predictors | Estimates | CI | p |
| (Intercept) | -2492.40 | -3753.90 – -1230.91 | <0.001 |
| cut(age, 4) (33.5,49] | 21.34 | 18.18 – 24.49 | <0.001 |
| cut(age, 4) (49,64.5] | 19.89 | 16.31 – 23.46 | <0.001 |
| cut(age, 4) (64.5,80.1] | 9.20 | 0.60 – 17.80 | 0.036 |
| edu | 15.76 | 14.71 – 16.81 | <0.001 |
| year | 1.27 | 0.64 – 1.90 | <0.001 |
| Observations | 3000 | ||
| R2 / R2 adjusted | 0.277 / 0.275 | ||
Table 3.7 Step Function (Piecewise-Constant) Regression Model
Figure 3.7 Scatter plot between Wage and Age with the step function.
According to Table 3.7, the model is \(Wage = -2492.40 + 21.34I_{33.5 < Age \leq 49} + 19.89I_{49 < Age \leq 64.5} + 9.20I_{64.5 < Age \leq 80.1} + 15.76(Edu) + 1.27(Year)\), where \(I\) is an indicator function that equals 1 when its condition holds and 0 otherwise. Workers aged 33.5 or younger form the baseline group. The corresponding \(R^2\) is 0.2766.
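To see how the indicator terms act as steps, the fitted equation can be evaluated by hand. The short Python sketch below copies the rounded coefficients from Table 3.7, so its results are approximate; the function name `predict_wage` and the example worker (education level 3, year 2006) are chosen only for illustration.

```python
# Coefficients copied from Table 3.7 (rounded to two decimals, so results are approximate).
INTERCEPT = -2492.40
AGE_STEPS = [            # (lower bound, exclusive; upper bound, inclusive; coefficient)
    (33.5, 49.0, 21.34),
    (49.0, 64.5, 19.89),
    (64.5, 80.1, 9.20),
]
EDU_COEF = 15.76
YEAR_COEF = 1.27

def predict_wage(age, edu, year):
    """Evaluate the fitted step-function model for one worker."""
    # Ages in the baseline bin (33.5 or younger) get no age adjustment.
    age_effect = sum(coef for lo, hi, coef in AGE_STEPS if lo < age <= hi)
    return INTERCEPT + age_effect + EDU_COEF * edu + YEAR_COEF * year

# Same education and year, different ages: only the step changes.
for age in (30, 40, 60, 70):
    print(age, round(predict_wage(age, edu=3, year=2006), 2))
```

Within a bin the prediction does not change with age at all; it only jumps when age crosses a cut point, which is what distinguishes a step function from the polynomial and spline fits.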
For both the basis spline and the natural spline, the number of knots (equivalently, the degrees of freedom) needs to be specified. One way to choose it is K-fold cross-validation; in this case, K is equal to 5. For both types of spline, the polynomial degree for age is 3.
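The cross-validation idea can be sketched in a few lines of Python with NumPy. This is an illustration on synthetic data, not a port of plot_kfold: it uses a cubic truncated-power spline basis instead of bs() or ns(), places knots at equally spaced quantiles, and the helper names `spline_design` and `cv_mse` are hypothetical.

```python
import numpy as np

def spline_design(x, knots):
    """Cubic truncated-power spline basis: 1, x, x^2, x^3, and (x - k)^3_+ per knot."""
    cols = [np.ones_like(x), x, x**2, x**3]
    cols += [np.clip(x - k, 0, None) ** 3 for k in knots]
    return np.column_stack(cols)

def cv_mse(x, y, n_knots, folds):
    """Mean validation MSE across folds for a spline with n_knots interior knots."""
    # Assumption: knots sit at equally spaced quantiles of x.
    knots = np.quantile(x, np.linspace(0, 1, n_knots + 2)[1:-1])
    errors = []
    for j in np.unique(folds):
        train, valid = folds != j, folds == j
        beta, *_ = np.linalg.lstsq(spline_design(x[train], knots), y[train], rcond=None)
        pred = spline_design(x[valid], knots) @ beta
        errors.append(np.mean((pred - y[valid]) ** 2))
    return float(np.mean(errors))

rng = np.random.default_rng(42)
n = 1000
age = rng.uniform(18, 80, size=n)
# Synthetic wage curve that rises and then falls with age, plus noise.
wage = 80 + 40 * np.sin((age - 18) / 62 * np.pi) + rng.normal(scale=15, size=n)
x = (age - 18) / 62                          # rescale to [0, 1] for numerical stability
folds = rng.permutation(np.arange(n) % 5)    # random assignment to 5 equal folds

results = {k: cv_mse(x, wage, k, folds) for k in range(1, 11)}
best = min(results, key=results.get)
print(best, round(results[best], 2))
```

The number of knots with the smallest average validation MSE is the one selected; on real data the curve of MSE against the number of knots is often quite flat, so nearby choices tend to perform similarly.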
Figure 3.6 The MSE of the basis spline from performing 5-fold cross validation
The MSE is lowest when the number of knots is equal to 2, which corresponds to df = 6 in plot_kfold (df = 4 + knots). We therefore fit the regression with a basis spline using df = 6.
model_basis <- lm(wage ~ bs(age, df = 6) + edu + year, data = data)
summary(model_basis)
##
## Call:
## lm(formula = wage ~ bs(age, df = 6) + edu + year, data = data)
##
## Residuals:
## Min 1Q Median 3Q Max
## -114.896 -19.443 -3.404 14.246 213.360
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -2433.6166 639.8762 -3.803 0.000146 ***
## bs(age, df = 6)1 8.6483 11.0196 0.785 0.432626
## bs(age, df = 6)2 31.0200 6.3549 4.881 1.11e-06 ***
## bs(age, df = 6)3 46.0792 7.3942 6.232 5.26e-10 ***
## bs(age, df = 6)4 32.3948 7.7626 4.173 3.09e-05 ***
## bs(age, df = 6)5 49.6746 12.1826 4.077 4.67e-05 ***
## bs(age, df = 6)6 5.5023 14.3407 0.384 0.701239
## edu 15.3322 0.5364 28.585 < 2e-16 ***
## year 1.2285 0.3191 3.850 0.000120 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 35.3 on 2991 degrees of freedom
## Multiple R-squared: 0.2864, Adjusted R-squared: 0.2845
## F-statistic: 150 on 8 and 2991 DF, p-value: < 2.2e-16
Figure 3.7 Scatter plot between Wage and Age with Regression line. (Basis Spline)
Figure 3.8 The MSE of the natural spline from performing 5-fold cross validation
The MSE is lowest when the number of knots is equal to 4, which corresponds to df = 6 in plot_kfold (df = 2 + knots). We therefore fit the regression with a natural spline using df = 6.
model_natural <- lm(wage ~ ns(age, df = 6) + education + year, data = data)
summary(model_natural)
##
## Call:
## lm(formula = wage ~ ns(age, df = 6) + education + year, data = data)
##
## Residuals:
## Min 1Q Median 3Q Max
## -121.403 -19.727 -3.143 14.174 214.340
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -2394.4450 638.1274 -3.752 0.000179 ***
## ns(age, df = 6)1 38.7338 4.6496 8.331 < 2e-16 ***
## ns(age, df = 6)2 46.4652 5.8970 7.879 4.57e-15 ***
## ns(age, df = 6)3 38.1178 5.1218 7.442 1.29e-13 ***
## ns(age, df = 6)4 37.0673 4.8062 7.712 1.67e-14 ***
## ns(age, df = 6)5 48.9899 11.6639 4.200 2.75e-05 ***
## ns(age, df = 6)6 4.3620 8.9214 0.489 0.624922
## education2. HS Grad 11.1264 2.4295 4.580 4.85e-06 ***
## education3. Some College 23.6491 2.5595 9.240 < 2e-16 ***
## education4. College Grad 38.3108 2.5454 15.051 < 2e-16 ***
## education5. Advanced Degree 62.5971 2.7605 22.676 < 2e-16 ***
## year 1.2186 0.3182 3.830 0.000131 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 35.16 on 2988 degrees of freedom
## Multiple R-squared: 0.2927, Adjusted R-squared: 0.2901
## F-statistic: 112.4 on 11 and 2988 DF, p-value: < 2.2e-16
Figure 3.9 Scatter plot between Wage and Age with Regression line. (Natural Spline)
The table below shows the \(R^2\) for each model.
| Model | Python | STATA | R |
|---|---|---|---|
| Linear Regression | xx | 0.2555 | 0.2555 |
| Cubic Polynomial Regression | xx | 0.2846 | 0.2846 |
| Piecewise Linear Regression | xx | 0.2882 | 0.2766 |
| Basis Spline | xx | 0.9126 | 0.2864 |
| Natural Spline | xx | xx | 0.2927 |
Table 4.1 The coefficient of determination (\(R^2\)) for each model from Python, STATA and R.
Gareth James, Daniela Witten, Trevor Hastie, Robert Tibshirani. (2013). An Introduction to Statistical Learning: with Applications in R. New York: Springer.
Data Description: https://rdrr.io/cran/ISLR/man/Wage.html